# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Code for clustering strings by edit distance.

Includes exact and approximate strategies for clustering strings ("sequences")
of the same length based on edit distance (Hamming or Levenshtein). Developed
for clustering DNA sequences, but should be able to handle clustering for
arbitrary strings.

To constrain the nearest neighbor search necessary for clustering, we use exact
and randomized approaches based on Locality Sensitive Hashing. Somewhat similar
approaches have been described in the research literature [1]. Assuming clusters
of constant size, this allows our clustering algorithms to run in linear time as
the number of sequences increases.

Typical usage example:

  >>> sequences = ['AAA', 'ATA', 'GGG']
  >>> clustering.cluster_by_edit_distance(sequences, edit_distance=1)
  [1, 1, 2]

References:
  [1] http://www.ncbi.nlm.nih.gov/pmc/articles/PMC4281958/
"""

import collections
import itertools
import logging
import math
import random


import Levenshtein
import numpy
import six

# Google internal
import gfile
import results_pb2
import sstable


# Shared "no candidates" result returned by ScaMMatcher.match and
# HashMatcher.match.
# NOTE(review): this single list object is handed to callers; mutating the
# returned value would change every later "no match" result.
_EMPTY_LIST = []


class AbstractMatcher:
  """Interface for objects that propose candidate neighbors of a sequence.

  Subclasses override `match`. Candidates may be over-inclusive; callers such
  as `explore_cluster` re-check each one with an exact distance function.
  """

  def match(self, sequence):
    """Look up candidate neighbors of `sequence`.

    Only sequences that were supplied when the matcher was built are
    guaranteed to give meaningful results; this precondition is not checked.

    Args:
      sequence: string whose neighbors should be returned.

    Returns:
      Iterable of strings near `sequence`. Depending on the distance metric,
      it may contain false positives.

    Raises:
      NotImplementedError: always; concrete subclasses must override this.
    """
    raise NotImplementedError()


class _IntegerEncoder:
  """Build a encoding of the given keys as integers.

  Attributes:
    key_to_id: dict mapping keys to encoded integers.
    id_to_key: list mapping integer IDs to decoded keys.
  """

  def __init__(self):
    self.key_to_id = {}
    self.id_to_key = []

  def __getitem__(self, key):
    """Lookup the ID corresponding to the given key.

    If key does not yet have an ID, assign the next available integer ID.

    Args:
      key: hashable value.

    Returns:
      Integer ID.
    """
    try:
      return self.key_to_id[key]
    except KeyError:
      identifier = len(self.key_to_id)
      self.key_to_id[key] = identifier
      self.id_to_key.append(key)
      return identifier


class ScaMMatcher(AbstractMatcher):
  """Matcher that uses pre-computed lookup tables generated by ScaM."""

  def __init__(self, table, dtype='u4'):
    """Initialize a ScaMMatcher.

    Args:
      table: Mapping[str, results_pb2.NearestNeighbors] mapping each sequence
        to all of its (approximate) neighbors within some fixed edit distance.
        Each value is read through its `docid` field and the `docid` of each
        entry in its repeated `neighbor` field.
      dtype: optional object convertible to numpy.dtype to use for storing
        non-negative integer IDs.

    Raises:
      ValueError: if dtype was not big enough to hold every integer ID.
    """
    neighbors = {}
    encoder = _IntegerEncoder()
    n_entries = len(table)
    # Capacity is checked before every numpy conversion so that running out
    # of IDs raises the documented ValueError. Otherwise numpy.array may
    # silently wrap (older NumPy) or raise OverflowError (NumPy >= 2).
    max_id = numpy.iinfo(dtype).max

    for n, (sequence, value) in enumerate(table.items()):
      if n % 1000000 == 0 or (n < 1000000 and n % 100000 == 0):
        logging.info('loading ScaM results %r/%r', n, n_entries)

      # Exclude the entry itself (value.docid) from its own neighbor list.
      neighbor_sequences = [neighbor.docid for neighbor in value.neighbor
                            if neighbor.docid != value.docid]
      if neighbor_sequences:
        neighbor_ids = [encoder[seq] for seq in neighbor_sequences]
        if len(encoder.key_to_id) > max_id:
          raise ValueError('ran out of integer IDs')
        # Only sequences with at least one neighbor get an entry in
        # `neighbors`. The right-hand side is evaluated first, so the key's ID
        # is allocated after its neighbors' IDs.
        neighbors[encoder[sequence]] = numpy.array(neighbor_ids, dtype=dtype)

    logging.info('finished loading ScaM results')

    # The last key encoded above may have allocated one more ID.
    if len(encoder.key_to_id) > max_id:
      raise ValueError('ran out of integer IDs')

    self._neighbors = neighbors
    self._sequence_to_id = encoder.key_to_id
    self._id_to_sequence = encoder.id_to_key

  @classmethod
  def from_path(cls, pattern):
    """Create a ScaMMatcher from SSTables of ScaM NearestNeighbors results.

    Args:
      pattern: string pattern for paths to sstables holding output from the ScaM
        map-reduce.

    Returns:
      ScaMMatcher for doing lookups with these pre-computed neighbors.
    """
    paths = sorted(gfile.Glob(pattern))
    wrapper = sstable.TableWrapper(results_pb2.NearestNeighbors.FromString)
    table = sstable.ShardedSSTable(paths, wrapper=wrapper)
    return cls(table)

  def match(self, sequence):
    """See base class.

    Returns an empty list for sequences that are unknown, and for sequences
    that only ever appeared as another sequence's neighbor and therefore have
    no neighbor list of their own. Previously the latter raised KeyError.
    """
    sequence_id = self._sequence_to_id.get(sequence)
    if sequence_id is None:
      return _EMPTY_LIST
    neighbor_ids = self._neighbors.get(sequence_id)
    if neighbor_ids is None:
      return _EMPTY_LIST
    return [self._id_to_sequence[id_] for id_ in neighbor_ids]


class HashMatcher(AbstractMatcher):
  """Match sequences that share a bucket under a single hash function.

  Every input sequence is indexed under the hash of each of its first
  `max_shift + 1` left rotations. Lookups hash the query without rotating it.
  """

  def __init__(self, sequences, hash_func, max_shift=0):
    """Initialize a HashMatcher.

    Args:
      sequences: a sequence of strings to match.
      hash_func: callable that maps a sequence to the key of a hash bucket.
      max_shift: optional integer giving the maximum number of positional
        shifts (left rotations) to index for each sequence.
    """
    table = {}
    for seq in sequences:
      for offset in range(max_shift + 1):
        bucket_key = hash_func(seq[offset:] + seq[:offset])
        # Buckets are created by hand rather than with
        # collections.defaultdict(list).
        bucket = table.get(bucket_key)
        if bucket is None:
          table[bucket_key] = [seq]
        else:
          bucket.append(seq)
    # Queries come from the indexed sequences, so a single-entry bucket can
    # only ever return the query itself. Drop such buckets to save memory.
    self._buckets = {
        bucket_key: members
        for bucket_key, members in table.items()
        if len(members) > 1
    }
    self._hash = hash_func

  def match(self, sequence):
    """See base class.

    The returned list is either an internal bucket or the shared empty list,
    so callers should not mutate it.
    """
    return self._buckets.get(self._hash(sequence), _EMPTY_LIST)


class LSHMatcher(AbstractMatcher):
  """Match sequences by pooling candidates from several HashMatchers."""

  def __init__(self, sequences, hash_functions, max_shift=0):
    """Initialize a LSHMatcher.

    Args:
      sequences: a sequence of sequences to partition. It is iterated once per
        hash function, so it must be re-iterable.
      hash_functions: iterable of hash functions (callables). Each one gets its
        own HashMatcher.
      max_shift: optional integer giving the maximum number of positional shifts
        to consider when partitioning.
    """
    self._partitions = [
        HashMatcher(sequences, hash_fn, max_shift) for hash_fn in hash_functions
    ]

  def match(self, sequence):
    """See base class.

    Returns a set holding the union of the candidates from every partition.
    """
    return set().union(
        *[partition.match(sequence) for partition in self._partitions])


def exact_lsh_matches(sequences, edit_distance, measure='levenshtein',
                      target_occupancy=0.5, num_choices=4):
  """Build a callable for 'exact' matching of sequences using LSH.

  Each sequence is split into `segment_count` equal-width segments. One hash
  table is built for every choice of (segment_count - edit_distance)
  segments. Two sequences within Hamming distance `edit_distance` have
  mismatches in at most `edit_distance` segments, so they agree on at least
  that many segments and share a bucket in at least one table. These matches
  are proven to be exact for Hamming distance, but we're not quite sure it's
  correct for Levenshtein distance.

  Args:
    sequences: sequence of strings to partition.
    edit_distance: maximum edit distance between any sequence and the closest
      other sequence in the same cluster.
    measure: optional string 'levenshtein' or 'hamming', indicating how to
      calculate edit distance.
    target_occupancy: optional float indicating the maximum acceptable average
      number of randomly chosen sequences that would appear in the same hash
      bucket by chance.
    num_choices: optional integer giving the number of valid sequence
      elements. Default value is 4, corresponding to the four base pairs in
      DNA.

  Returns:
    Callable mapping a sequence to a set of candidate neighbors.

  Raises:
    ValueError: if `sequences` is empty, its elements differ in length, or
      `measure` is invalid.
  """
  sequence_length = _unique_length(sequences)
  max_shift = _max_shift(edit_distance, measure)
  min_hash_length = optimal_hash_length(
      len(sequences), max_shift, target_occupancy, num_choices)
  segment_count = _required_segment_count(
      sequence_length, min_hash_length, edit_distance)
  # If edit_distance >= segment_count, no segment is guaranteed to match.
  # itertools.combinations rejects a negative size, so clamp to 0. That
  # yields one empty hash key: a single bucket with every sequence. This is
  # brute force, but still exact.
  segments_per_hash = max(segment_count - edit_distance, 0)
  all_segments = itertools.combinations(range(segment_count), segments_per_hash)
  hashes = (segmented_hash(segments, segment_count, sequence_length)
            for segments in all_segments)
  return LSHMatcher(sequences, hashes, max_shift).match


def approximate_lsh_matches(sequences, edit_distance, measure='levenshtein',
                            hash_length=10, num_rounds=10, seed=None):
  """Build a callable for matching of sequences using LSH.

  This approach was designed for Hamming distance. It may perform very poorly
  for Levenshtein distance. Each round hashes every sequence by a random
  subset of `hash_length` positions. Candidates are pooled across rounds.

  TODO(shoyer): refactor this API to take a desired success_probability
  instead of this lower level API.

  Args:
    sequences: sequence of strings to partition.
    edit_distance: maximum edit distance between any sequence and the closest
      other sequence in the same cluster.
    measure: optional string 'levenshtein' or 'hamming', indicating how to
      calculate edit distance.
    hash_length: integer number of bases from each sequence to use in the hash
      key. Must not exceed the sequence length.
    num_rounds: integer number of random partitions to create.
    seed: optional hashable random seed. It makes the randomly chosen hash
      positions, and therefore the returned matcher, reproducible.

  Returns:
    Callable for finding matches.

  Raises:
    ValueError: if `sequences` is empty or its elements differ in length, if
      `hash_length` exceeds the sequence length, or if `measure` is invalid.
  """
  sequence_length = _unique_length(sequences)
  if hash_length > sequence_length:
    # Fail clearly here instead of via an opaque error from random.sample.
    raise ValueError('hash_length %r exceeds sequence length %r'
                     % (hash_length, sequence_length))
  rand = random.Random(seed)
  hashes = (random_hash(hash_length, sequence_length, rand.random())
            for _ in range(num_rounds))
  max_shift = _max_shift(edit_distance, measure)
  return LSHMatcher(sequences, hashes, max_shift).match


def _unique_length(elements):
  """Calculate the unique length of the provided elements.

  Args:
    elements: iterable of objects with a defined length.

  Returns:
    Integer unique length.

  Raises:
    ValueError: if there is no unique length, or if the iterable is empty.
  """
  lengths = set(len(elem) for elem in elements)
  if not lengths:
    raise ValueError('no sequences provided')
  length = lengths.pop()
  if lengths:
    raise ValueError('sequences to cluster must have a unique length')
  return length


def optimal_hash_length(num_sequences, max_shift=0, target_occupancy=0.5,
                        num_choices=4):
  """Return the shortest hash length that keeps random collisions rare.

  The smallest length L with num_choices**L >= indexed_items /
  target_occupancy is chosen. Here indexed_items counts every shifted copy
  of every sequence.

  Args:
    num_sequences: integer number of distinct sequences to partition.
    max_shift: optional integer giving the maximum number of positional shifts
      to consider when partitioning.
    target_occupancy: float indicating the maximum acceptable average number of
      randomly chosen sequences that would appear in the same hash bucket by
      chance.
    num_choices: optional integer giving the number of valid sequence elements.
      Default value is 4, corresponding to the four base pairs in DNA.

  Returns:
    Integer giving the optimal hash length.
  """
  required_buckets = float(num_sequences * (max_shift + 1)) / target_occupancy
  fractional_length = math.log(required_buckets) / math.log(num_choices)
  return int(math.ceil(fractional_length))


def _required_segment_count(sequence_length, min_hash_length, edit_distance):
  # The length of each hash, which is constructed from
  # (segment_count - edit_distance) out of segement_count segments, should be at
  # least min_hash_length. This means that we need to satisfy:
  #   sequence_length * (segment_count - edit_distance) / segment_count
  #     >= min_hash_length
  #
  # From some algebra, it follows that segment_count should be given by:
  return max(int(math.ceil(edit_distance /
                           (1 - float(min_hash_length) / sequence_length))), 1)


def segmented_hash(segments, segment_count, sequence_length):
  """Build a hash function that concatenates selected equal-width segments.

  The sequence is divided into `segment_count` segments, each
  `sequence_length // segment_count` characters wide. Any leftover trailing
  characters belong to no segment. The returned function keeps only the
  segments listed in `segments`, in the given order, and joins them into a
  single string key.

  Args:
    segments: tuple of integers indicating segments to include in the hash.
    segment_count: integer indicating the total number of segments.
    sequence_length: integer length of all sequences in this partition.

  Returns:
    Callable mapping a sequence to its string hash key.
  """
  width = sequence_length // segment_count
  windows = [slice(index * width, (index + 1) * width) for index in segments]

  def hash_func(sequence):
    return ''.join([sequence[window] for window in windows])

  return hash_func


def random_hash(hash_length, sequence_length, seed=None):
  """Build a hash function that keeps a random subset of positions.

  `hash_length` distinct positions are drawn from a private RNG seeded with
  `seed`. The characters at those positions, in increasing position order,
  form the key.

  Args:
    hash_length: integer number of bases from each sequence to use in the hash
      key.
    sequence_length: integer length of all sequences in this partition.
    seed: optional hashable random seed to guarantee reproducible results.

  Returns:
    Callable mapping a sequence to its string hash key.
  """
  positions = random.Random(seed).sample(range(sequence_length), hash_length)
  positions.sort()

  def hash_func(sequence):
    return ''.join([sequence[position] for position in positions])

  return hash_func


def _max_shift(edit_distance, measure='levenshtein'):
  """Determine the maximum number of shifted positions at an edit distance.

  Args:
    edit_distance: maximum edit distance between sequences.
    measure: optional string 'levenshtein' or 'hamming', indicating how to
      calculate edit distance.

  Returns:
    Integer giving the largest number of posible shifts in position between the
    two sequences at the given edit distance.

  Raises:
    ValueError: if `measure` is invalid.
  """
  if measure == 'levenshtein':
    # as long as all sequences are restricted to have the same length, it
    # requires two edits (one insertion and one deletion) shift bases by one
    # position
    max_shift = edit_distance // 2
  elif measure == 'hamming':
    # shifting isn't necessary for Hamming distance
    max_shift = 0
  else:
    raise ValueError('unexpected measure %r' % measure)
  return max_shift


def hamming_distance(seq1, seq2):
  """Compute the Hamming distance between two sequences.

  This is the edit distance from seq1 to seq2 when only substitutions are
  allowed, with no insertions or deletions. As in `levenshtein_distance`,
  both arguments are first normalized with `six.ensure_str`, so bytes and str
  inputs can be mixed.

  NOTE(review): whether sequences of different lengths raise an error or are
  padded depends on the installed `Levenshtein` version. Confirm before
  relying on it.

  Args:
    seq1: sequence 1 (str or bytes).
    seq2: sequence 2 (str or bytes).

  Returns:
    An integer giving the Hamming distance.
  """
  seq1 = six.ensure_str(seq1)
  seq2 = six.ensure_str(seq2)
  return Levenshtein.hamming(seq1, seq2)


def levenshtein_distance(seq1, seq2):
  """Compute the Levenshtein distance between two sequences.

  This is the minimum number of substitutions, insertions and deletions that
  turn seq1 into seq2. Both arguments are first passed through
  `six.ensure_str`, so bytes and str inputs can be mixed.

  Args:
    seq1: sequence 1 (str or bytes).
    seq2: sequence 2 (str or bytes).

  Returns:
    An integer giving the Levenshtein distance.
  """
  first, second = (six.ensure_str(seq) for seq in (seq1, seq2))
  return Levenshtein.distance(first, second)


def explore_cluster(seed_sequence, edit_distance, measure, find_nearby):
  """Collect every sequence reachable from `seed_sequence`.

  A step from sequence `a` to sequence `b` is allowed when `find_nearby(a)`
  yields `b` and the distance between them under `measure` is at most
  `edit_distance`. The result is the closure of such steps.

  Args:
    seed_sequence: string around which to find neighbors.
    edit_distance: maximum edit distance between any sequence and the closest
      other sequence in the same cluster.
    measure: string 'levenshtein' or 'hamming', indicating how to calculate edit
      distance.
    find_nearby: callable that returns an iterable of all nearby sequences to a
      given sequence. These sequences are checked to see if they fall within the
      distance threshold.

  Returns:
    The set of all sequences in the same cluster as seed_sequence, including the
    seed sequence itself.

  Raises:
    ValueError: if `measure` is invalid.
  """
  if measure == 'hamming':
    calc_distance = hamming_distance
  elif measure == 'levenshtein':
    calc_distance = levenshtein_distance
  else:
    raise ValueError('unexpected measure %r' % measure)

  # Worklist traversal. `frontier` holds sequences that were discovered but
  # not yet expanded. set.pop() picks an arbitrary element, so this is
  # neither strictly depth- nor breadth-first. The final set does not depend
  # on the visit order.
  members = set()
  frontier = {seed_sequence}

  while frontier:
    current = frontier.pop()
    members.add(current)
    for candidate in find_nearby(current):
      if candidate in members or candidate in frontier:
        continue
      if calc_distance(candidate, current) <= edit_distance:
        frontier.add(candidate)

  return members


def cluster_by_edit_distance(sequences, edit_distance, measure='levenshtein',
                             find_nearby=None):
  """Cluster strings by edit distance.

  Every sequence in a cluster can be reached from any other sequence in the
  same cluster through a chain of steps. Each step joins two sequences whose
  distance under `measure` is no larger than `edit_distance`, where the
  second sequence was proposed as a candidate by `find_nearby`.

  NOTE: each exploration overwrites earlier assignments for the sequences it
  reaches. If `find_nearby` is not symmetric, a later cluster can therefore
  claim sequences from an earlier one, leaving gaps in the cluster IDs.

  Args:
    sequences: a sequence of strings with the same length, e.g., ['ATGC',
      'AATT', 'GCCG'].
    edit_distance: maximum edit distance between any sequence and the closest
      other sequence in the same cluster.
    measure: optional string 'levenshtein' or 'hamming', indicating how to
      calculate edit distance.
    find_nearby: optional callable that returns an iterable of all nearby
      sequences to a given sequence. These sequences are checked to see if they
      fall within the distance threshold. By default, returns a callable created
      by calling `exact_lsh_matches` on the provided sequences.

  Returns:
    A list of integer cluster IDs, one per input sequence. Cluster IDs are
    sequential integers starting from one and sorted by order of appearance.
    An empty input yields an empty list.

  Raises:
    ValueError: if not all sequences have the same length (checked when
      `find_nearby` is None), or if `measure` is invalid.
  """
  # len() rather than truthiness, so numpy arrays of strings also work.
  num_sequences = len(sequences)
  if num_sequences == 0:
    # Consistent result whether or not find_nearby is supplied; building the
    # default matcher would otherwise raise on empty input.
    return []

  if find_nearby is None:
    find_nearby = exact_lsh_matches(sequences, edit_distance, measure)

  # start from 1 so a default value of 0 corresponds to unclustered.
  cluster_idx = 1
  cluster_assignments = {}  # {sequence: cluster_id}

  for i, seed in enumerate(sequences):
    if seed not in cluster_assignments:
      for seq in explore_cluster(seed, edit_distance, measure, find_nearby):
        cluster_assignments[seq] = cluster_idx
      cluster_idx += 1
    if i % 1000000 == 0 or (i < 1000000 and i % 100000 == 0):
      logging.info('assigned clusters to first %r/%r sequences', i,
                   num_sequences)

  return [cluster_assignments[seq] for seq in sequences]


def set_of_cluster_sets(sequences, clusters):
  """Group sequences by cluster label into an order-invariant structure.

  Two clusterings that differ only in how their clusters are numbered
  compare equal under this representation.

  Args:
    sequences: list of strings indicating sequences.
    clusters: list of integers indicating cluster assignments.

  Returns:
    frozenset of frozensets of clustered sequences.

  Example:

    >>> set_of_cluster_sets(['AAA', 'AGA', 'GGG'], [0, 0, 1]) == frozenset(
    ...     [frozenset(['AAA', 'AGA']), frozenset(['GGG'])])
    True
  """
  members_by_label = {}
  for sequence, label in zip(sequences, clusters):
    members_by_label.setdefault(label, set()).add(sequence)
  return frozenset(map(frozenset, members_by_label.values()))
